{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "c9ff4273-cc93-4e03-b9ae-0a411fc4f450",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import cv2\n",
    "import json\n",
    "import numpy as np\n",
    "from PIL import Image, ImageDraw\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "from torchvision import transforms, models\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "3c64a8fb-8467-4a25-9c20-257aa4ebd001",
   "metadata": {},
   "outputs": [],
   "source": [
    "def json_to_mask(json_path, size):\n",
    "    \"\"\"\n",
    "    Rasterise the 'road' polygons of a LabelMe JSON file into a binary mask.\n",
    "\n",
    "    :param json_path: path to the LabelMe JSON file\n",
    "    :param size: mask size as (width, height)\n",
    "    :return: uint8 NumPy array, 1 inside 'road' polygons and 0 elsewhere\n",
    "    \"\"\"\n",
    "    with open(json_path, 'r') as f:\n",
    "        data = json.load(f)\n",
    "\n",
    "    # Single-channel canvas, initialised to 0 (background).\n",
    "    mask = Image.new('L', size, 0)\n",
    "    draw = ImageDraw.Draw(mask)\n",
    "\n",
    "    for shape in data['shapes']:\n",
    "        if shape['label'] != 'road':  # only the road label is drawn\n",
    "            continue\n",
    "        # json.load yields every point as a two-element list, while\n",
    "        # ImageDraw.polygon is documented to take 2-tuples (or a flat list\n",
    "        # of numbers), so convert each point explicitly.\n",
    "        polygon = [(x, y) for x, y in shape['points']]\n",
    "        draw.polygon(polygon, outline=1, fill=1)  # road pixels become 1\n",
    "\n",
    "    return np.array(mask, dtype=np.uint8)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "50943ab3-44e0-4569-8bd3-d6184875de0d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def parse_json_to_mask(json_path, image_size):\n",
    "    \"\"\"\n",
    "    Parse a LabelMe annotation file into a binary road mask.\n",
    "\n",
    "    :param json_path: path to the LabelMe JSON file\n",
    "    :param image_size: size of the annotated image as (width, height)\n",
    "    :return: uint8 NumPy array of shape (height, width); 1 inside polygons\n",
    "             labelled 'road', 0 elsewhere\n",
    "    \"\"\"\n",
    "    with open(json_path, 'r') as f:\n",
    "        data = json.load(f)\n",
    "\n",
    "    # Keep a reference to the canvas that is drawn on and build the result\n",
    "    # from it. Drawing on a throwaway Image.fromarray(zeros) and returning\n",
    "    # the zeros array would always produce an empty mask.\n",
    "    canvas = Image.new('L', (int(image_size[0]), int(image_size[1])), 0)\n",
    "    draw = ImageDraw.Draw(canvas)\n",
    "\n",
    "    for shape in data['shapes']:\n",
    "        if shape['label'] != 'road':  # only regions labelled 'road' are drawn\n",
    "            continue\n",
    "        polygon = [(int(x), int(y)) for x, y in shape['points']]\n",
    "        draw.polygon(polygon, outline=1, fill=1)\n",
    "\n",
    "    return np.array(canvas, dtype=np.uint8)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "a34fa75c-dd77-4b98-be03-14966350ad52",
   "metadata": {},
   "outputs": [],
   "source": [
    "def preprocess_image_and_mask(image_path, json_path, size=(256, 256)):\n",
    "    \"\"\"\n",
    "    Load an image and its LabelMe annotation as a pair of training tensors.\n",
    "\n",
    "    :param image_path: path to the input image\n",
    "    :param json_path: path to the LabelMe JSON file annotating that image\n",
    "    :param size: target size handed to transforms.Resize\n",
    "    :return: (image, mask). image is the float tensor produced by ToTensor\n",
    "             (channels first, values in [0, 1]); mask is an int64 tensor with\n",
    "             a leading channel axis of 1 holding class indices\n",
    "             (0 = background, 1 = road)\n",
    "    \"\"\"\n",
    "    # Open the file once and close it deterministically; only the RGB copy\n",
    "    # and the original (width, height) are needed afterwards.\n",
    "    with Image.open(image_path) as source:\n",
    "        original_size = source.size\n",
    "        rgb_image = source.convert('RGB')\n",
    "\n",
    "    transform_image = transforms.Compose([\n",
    "        transforms.Resize(size),\n",
    "        transforms.ToTensor()\n",
    "    ])\n",
    "    image = transform_image(rgb_image)\n",
    "\n",
    "    # Rasterise the annotation at the image's original resolution.\n",
    "    mask = parse_json_to_mask(json_path, original_size)\n",
    "\n",
    "    # Labels must remain exact class indices: resize with nearest-neighbour\n",
    "    # (bilinear would blend 0/1 into fractions) and skip ToTensor, which\n",
    "    # would rescale the uint8 label 1 to the float 1/255.\n",
    "    resize_mask = transforms.Resize(\n",
    "        size, interpolation=transforms.InterpolationMode.NEAREST\n",
    "    )\n",
    "    mask = resize_mask(Image.fromarray(mask))\n",
    "    mask = torch.from_numpy(np.array(mask, dtype=np.int64)).unsqueeze(0)\n",
    "\n",
    "    return image, mask"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "8585073a-8abd-4d56-ad31-c9ec38f5a5d2",
   "metadata": {},
   "outputs": [],
   "source": [
    "class RoadSegmentationDataset(Dataset):\n",
    "    \"\"\"Dataset pairing images with their LabelMe JSON annotations.\"\"\"\n",
    "\n",
    "    def __init__(self, image_paths, json_paths, size=(256, 256)):\n",
    "        \"\"\"\n",
    "        Remember the sample paths; files are only read when a sample is indexed.\n",
    "\n",
    "        :param image_paths: list of image paths\n",
    "        :param json_paths: list of JSON annotation paths, one per image\n",
    "        :param size: target size for both the image and the mask\n",
    "        \"\"\"\n",
    "        n_images, n_annotations = len(image_paths), len(json_paths)\n",
    "        assert n_images == n_annotations, \"图像和 JSON 的数量必须一致！\"\n",
    "        self.size = size\n",
    "        self.json_paths = json_paths\n",
    "        self.image_paths = image_paths\n",
    "\n",
    "    def __len__(self):\n",
    "        \"\"\"Number of (image, annotation) pairs.\"\"\"\n",
    "        return len(self.image_paths)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        \"\"\"Return sample ``idx`` as a dict with 'image' and 'mask' entries.\"\"\"\n",
    "        image, mask = preprocess_image_and_mask(\n",
    "            self.image_paths[idx], self.json_paths[idx], self.size\n",
    "        )\n",
    "        return dict(image=image, mask=mask)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "0a1110b1-8f39-4c9e-bd63-e2ba154834a3",
   "metadata": {},
   "outputs": [],
   "source": [
    "class RoadSegmentationModel(nn.Module):\n",
    "    \"\"\"\n",
    "    Segmentation network: a ResNet-18 trunk as encoder followed by a small\n",
    "    convolutional decoder that produces per-pixel class logits.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, num_classes=2):\n",
    "        \"\"\"\n",
    "        :param num_classes: number of output channels (one per class)\n",
    "        \"\"\"\n",
    "        super().__init__()\n",
    "        # 'pretrained=True' is deprecated since torchvision 0.13 and is\n",
    "        # equivalent to weights=ResNet18_Weights.IMAGENET1K_V1; keep it only\n",
    "        # as a fallback for releases that predate the weights enum.\n",
    "        if hasattr(models, 'ResNet18_Weights'):\n",
    "            resnet = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)\n",
    "        else:\n",
    "            resnet = models.resnet18(pretrained=True)\n",
    "        # Keep everything except the last two children (pooling and\n",
    "        # classifier head) as the feature extractor.\n",
    "        self.encoder = nn.Sequential(*list(resnet.children())[:-2])\n",
    "\n",
    "        # Four conv + 2x upsample stages, then a 1x1 conv to class logits.\n",
    "        self.decoder = nn.Sequential(\n",
    "            nn.Conv2d(512, 256, kernel_size=3, padding=1),\n",
    "            nn.ReLU(),\n",
    "            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),\n",
    "            nn.Conv2d(256, 128, kernel_size=3, padding=1),\n",
    "            nn.ReLU(),\n",
    "            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),\n",
    "            nn.Conv2d(128, 64, kernel_size=3, padding=1),\n",
    "            nn.ReLU(),\n",
    "            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),\n",
    "            nn.Conv2d(64, 32, kernel_size=3, padding=1),\n",
    "            nn.ReLU(),\n",
    "            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),\n",
    "            nn.Conv2d(32, num_classes, kernel_size=1)\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"\n",
    "        :param x: image batch of shape (N, 3, H, W)\n",
    "        :return: class logits of shape (N, num_classes, H, W)\n",
    "        \"\"\"\n",
    "        input_size = x.shape[-2:]\n",
    "        x = self.encoder(x)\n",
    "        x = self.decoder(x)\n",
    "        # The decoder only upsamples 16x, which leaves a 256x256 input at\n",
    "        # 128x128. Resample the logits to the input resolution so that they\n",
    "        # line up with the target mask when the loss is computed.\n",
    "        x = nn.functional.interpolate(\n",
    "            x, size=input_size, mode='bilinear', align_corners=False\n",
    "        )\n",
    "        return x\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "9d10728f-6431-4d36-b69d-bafa37a5776e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_model(model, dataloader, criterion, optimizer, num_epochs=10):\n",
    "    \"\"\"\n",
    "    Train ``model`` for ``num_epochs`` passes over ``dataloader``.\n",
    "\n",
    "    :param model: network mapping an image batch to per-pixel class logits\n",
    "    :param dataloader: iterable of dicts with an 'image' batch and a 'mask'\n",
    "                       batch whose axis 1 is a singleton channel axis\n",
    "    :param criterion: loss called as criterion(logits, target), where target\n",
    "                      holds integer class indices without a channel axis\n",
    "    :param optimizer: optimizer updating the model's parameters\n",
    "    :param num_epochs: number of passes over the data\n",
    "    :return: list with the mean batch loss of every epoch (NaN for an epoch\n",
    "             that yielded no batches)\n",
    "    \"\"\"\n",
    "    # Follow the model's own device instead of assuming CUDA is in use\n",
    "    # whenever it is available; a CPU model must not be fed GPU tensors.\n",
    "    first_param = next(model.parameters(), None)\n",
    "    device = first_param.device if first_param is not None else torch.device('cpu')\n",
    "\n",
    "    epoch_losses = []\n",
    "    model.train()\n",
    "    for epoch in range(num_epochs):\n",
    "        total_loss = 0.0\n",
    "        num_batches = 0\n",
    "        for batch in dataloader:\n",
    "            images = batch['image'].to(device)\n",
    "            masks = batch['mask'].to(device)\n",
    "\n",
    "            # Drop the channel axis and turn the mask into integer class\n",
    "            # indices. Floating-point masks (e.g. from ToTensor, where the\n",
    "            # label 1 becomes 1/255) are binarised first, because a plain\n",
    "            # cast would truncate them to 0.\n",
    "            target = masks.squeeze(1)\n",
    "            if target.is_floating_point():\n",
    "                target = target > 0\n",
    "            target = target.long()\n",
    "\n",
    "            # Forward pass\n",
    "            outputs = model(images)\n",
    "            loss = criterion(outputs, target)\n",
    "\n",
    "            # Backward pass and parameter update\n",
    "            optimizer.zero_grad()\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "\n",
    "            total_loss += loss.item()\n",
    "            num_batches += 1\n",
    "\n",
    "        # Count batches directly: this works for iterables without __len__\n",
    "        # and avoids dividing by zero when the loader is empty.\n",
    "        mean_loss = total_loss / num_batches if num_batches else float('nan')\n",
    "        epoch_losses.append(mean_loss)\n",
    "        print(f\"Epoch [{epoch+1}/{num_epochs}], Loss: {mean_loss}\")\n",
    "\n",
    "    return epoch_losses\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "909180b2-b4ea-4bf3-a405-1c1f4c0adf43",
   "metadata": {},
   "outputs": [],
   "source": [
    "def test_and_visualize(model, image_path, json_path, size=(256, 256)):\n",
    "    \"\"\"\n",
    "    Run the model on one image and show the input next to the predicted mask.\n",
    "\n",
    "    :param model: segmentation network returning per-pixel class logits\n",
    "    :param image_path: path to the image to segment\n",
    "    :param json_path: path to the image's LabelMe JSON; needed by the shared\n",
    "                      preprocessing, the resulting mask is not displayed\n",
    "    :param size: target size forwarded to the preprocessing\n",
    "    \"\"\"\n",
    "    model.eval()  # the model is left in eval mode when this returns\n",
    "    image, _ = preprocess_image_and_mask(image_path, json_path, size)\n",
    "\n",
    "    # Put the input on the model's own device rather than assuming CUDA is\n",
    "    # in use whenever it happens to be available.\n",
    "    first_param = next(model.parameters(), None)\n",
    "    if first_param is not None:\n",
    "        image = image.to(first_param.device)\n",
    "\n",
    "    with torch.no_grad():\n",
    "        output = model(image.unsqueeze(0))  # add the batch axis\n",
    "        # Index the batch axis explicitly; a bare squeeze() would also drop\n",
    "        # a height or width of size 1.\n",
    "        prediction = torch.argmax(output[0], dim=0).cpu().numpy()\n",
    "\n",
    "    _, (ax_image, ax_mask) = plt.subplots(1, 2, figsize=(10, 5))\n",
    "    ax_image.set_title('Original Image')\n",
    "    ax_image.imshow(image.permute(1, 2, 0).cpu().numpy())\n",
    "    ax_mask.set_title('Predicted Mask')\n",
    "    ax_mask.imshow(prediction, cmap='gray')\n",
    "    plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "fddfad1b-05ce-45ce-86ca-a5c0e6400e7d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([3, 256, 256]) torch.Size([1, 256, 256])\n"
     ]
    }
   ],
   "source": [
    "# Sample selection: `target` is a 1-based index into the available samples\n",
    "target = 1\n",
    "data_name = ('0618', '0854', '1066')[target - 1]\n",
    "image_path, json_path = f'../input_data/{data_name}.png', f'../input_data/{data_name}.json'\n",
    "\n",
    "# Sanity check: preprocess the selected sample and print both tensor shapes\n",
    "image, mask = preprocess_image_and_mask(image_path, json_path, size=(256, 256))\n",
    "print(image.shape, mask.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7376af8-5fe3-4c8b-a6ef-a050b741a5cd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed the RNG so the randomly initialised decoder, and therefore the\n",
    "# training run, is repeatable\n",
    "torch.manual_seed(42)\n",
    "\n",
    "# Dataset and loader (a single annotated image)\n",
    "dataset = RoadSegmentationDataset([image_path], [json_path])\n",
    "dataloader = DataLoader(dataset, batch_size=4, shuffle=False)\n",
    "\n",
    "# Model, moved to the GPU when one is available\n",
    "model = RoadSegmentationModel(num_classes=2)\n",
    "if torch.cuda.is_available():\n",
    "    model = model.cuda()\n",
    "\n",
    "# Loss and optimizer: cross-entropy over the two classes\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n",
    "\n",
    "# Train\n",
    "train_model(model, dataloader, criterion, optimizer, num_epochs=10)\n",
    "\n",
    "# Run the trained model on the same image and visualise the prediction\n",
    "test_and_visualize(model, image_path, json_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a6aebbdf-ef95-4fe8-af1f-e54f27953273",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.20"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
